{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02162.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02162.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "geNxJNM5lSmb",
        "colab_type": "text"
      },
      "source": [
        "## 5.11 残差网络 ( ResNet )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1cxdqJrklmbF",
        "colab_type": "text"
      },
      "source": [
        "### 5.11.1 残差块"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KfiVeqf-ltFQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torch import nn, optim \n",
        "import torch.nn.functional as F \n",
        "import d2l \n",
        "import time \n",
        "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "\n",
        "class Residual(nn.Module):\n",
        "    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):\n",
        "        super(Residual, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)\n",
        "        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)\n",
        "        if use_1x1conv:\n",
        "            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)\n",
        "        else:\n",
        "            self.conv3 = None \n",
        "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
        "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
        "\n",
        "    def forward(self, X):\n",
        "        Y = F.relu(self.bn1(self.conv1(X)))\n",
        "        Y = self.bn2(self.conv2(Y))\n",
        "        if self.conv3:\n",
        "            X = self.conv3(X)\n",
        "        return F.relu(X + Y)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D0sri5SpFOQg",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "de13441b-27cc-41d1-af23-fe2ff96f708b"
      },
      "source": [
        "# Sanity check: with the default arguments (no 1x1 conv, stride 1) the block\n",
        "# returns a tensor of the same shape as its input.\n",
        "blk = Residual(3, 3)\n",
        "X = torch.rand((4, 3, 6, 6))\n",
        "blk(X).shape"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([4, 3, 6, 6])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z_PsIRfEFW8Q",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "a3b7c48d-3d03-4a45-a8be-479e2b64bf97"
      },
      "source": [
        "# With the 1x1 shortcut convolution and stride 2 the block maps 3 to 6 channels\n",
        "# and 6x6 to 3x3. Reuses the X defined in the previous cell.\n",
        "blk = Residual(3, 6, use_1x1conv=True, stride=2)\n",
        "blk(X).shape"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([4, 6, 3, 3])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gkZcbCqxFfc5",
        "colab_type": "text"
      },
      "source": [
        "### 5.11.2 ResNet 模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8PyNm8mdFo7z",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Stem of the network: 7x7 stride-2 convolution (1 input channel to 64),\n",
        "# batch norm, ReLU, then a 3x3 stride-2 max pooling layer.\n",
        "net = nn.Sequential(\n",
        "    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), \n",
        "    nn.BatchNorm2d(64), \n",
        "    nn.ReLU(), \n",
        "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bGTDoGgrF5pc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def resnet_block(in_channels, out_channels, num_residuals, first_block=False):\n",
        "    \"\"\"Stack num_residuals Residual blocks into one nn.Sequential stage.\n",
        "\n",
        "    Args:\n",
        "        in_channels: channel count entering the stage.\n",
        "        out_channels: channel count of every block's output.\n",
        "        num_residuals: number of Residual blocks in the stage.\n",
        "        first_block: when True, no block uses the strided 1x1-conv shortcut,\n",
        "            and in_channels must equal out_channels (checked with assert).\n",
        "\n",
        "    Returns:\n",
        "        nn.Sequential containing the blocks in order.\n",
        "    \"\"\"\n",
        "    if first_block:\n",
        "        assert in_channels == out_channels \n",
        "    # Only the leading block of a non-first stage changes channels (stride 2).\n",
        "    downsample_first = not first_block\n",
        "    layers = [\n",
        "        Residual(in_channels, out_channels, use_1x1conv=True, stride=2)\n",
        "        if idx == 0 and downsample_first\n",
        "        else Residual(out_channels, out_channels)\n",
        "        for idx in range(num_residuals)\n",
        "    ]\n",
        "    return nn.Sequential(*layers)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Eigf-R62GfLa",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Append four stages of two residual blocks each. Only the first stage is\n",
        "# built with first_block=True (64 to 64 channels); the remaining stages go\n",
        "# 64 to 128, 128 to 256 and 256 to 512 channels.\n",
        "net.add_module('resnet_block1', resnet_block(64, 64, 2, first_block=True))\n",
        "net.add_module('resnet_block2', resnet_block(64, 128, 2))\n",
        "net.add_module('resnet_block3', resnet_block(128, 256, 2))\n",
        "net.add_module('resnet_block4', resnet_block(256, 512, 2))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Mt8019xUHG7X",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Classification head. GlobalAvgPool2d and FlattenLayer come from the external\n",
        "# d2l helper module (presumably global average pooling and a flatten to 2-D;\n",
        "# verify against d2l). The linear layer maps 512 features to 10 outputs.\n",
        "net.add_module('global_avg_pool', d2l.GlobalAvgPool2d())\n",
        "net.add_module('fc', nn.Sequential(d2l.FlattenLayer(), nn.Linear(512, 10)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H71jIZRcHoW-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 185
        },
        "outputId": "d09898f8-3f92-4cc9-c427-a757802ea3bc"
      },
      "source": [
        "# Shape probe: push one random 1x224x224 image through each top-level child\n",
        "# of net and print the output shape. The probe runs in eval mode under\n",
        "# no_grad so that it neither builds an autograd graph nor lets random data\n",
        "# update the BatchNorm running statistics before training. The previous\n",
        "# training mode is restored afterwards, even if a layer raises.\n",
        "was_training = net.training\n",
        "net.eval()\n",
        "try:\n",
        "    with torch.no_grad():\n",
        "        X = torch.rand((1, 1, 224, 224))\n",
        "        for name, layer in net.named_children():\n",
        "            X = layer(X)\n",
        "            print(name, 'output shape:\\t', X.shape)\n",
        "finally:\n",
        "    net.train(was_training)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0 output shape:\t torch.Size([1, 64, 112, 112])\n",
            "1 output shape:\t torch.Size([1, 64, 112, 112])\n",
            "2 output shape:\t torch.Size([1, 64, 112, 112])\n",
            "3 output shape:\t torch.Size([1, 64, 56, 56])\n",
            "resnet_block1 output shape:\t torch.Size([1, 64, 56, 56])\n",
            "resnet_block2 output shape:\t torch.Size([1, 128, 28, 28])\n",
            "resnet_block3 output shape:\t torch.Size([1, 256, 14, 14])\n",
            "resnet_block4 output shape:\t torch.Size([1, 512, 7, 7])\n",
            "global_avg_pool output shape:\t torch.Size([1, 512, 1, 1])\n",
            "fc output shape:\t torch.Size([1, 10])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iiNYYW18H5-O",
        "colab_type": "text"
      },
      "source": [
        "### 5.11.3 获取数据和训练模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-z732pTKIB00",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "488bb525-3b8d-4b8a-8d60-b5f8952d2c0e"
      },
      "source": [
        "# Load Fashion-MNIST through the external d2l helper. resize=96 is forwarded\n",
        "# to the loader; presumably images are resized to 96x96 (TODO confirm in d2l).\n",
        "batch_size = 256 \n",
        "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
        "\n",
        "# Adam over all parameters of net with learning rate lr, for num_epochs epochs.\n",
        "lr, num_epochs = 0.001, 5 \n",
        "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
        "# NOTE(review): no random seed is set in the visible cells, so the recorded\n",
        "# accuracies are not expected to reproduce exactly.\n",
        "# d2l.train_ch5 is an external helper; its recorded log reports per-epoch\n",
        "# loss, train/test accuracy and elapsed time.\n",
        "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on  cuda\n",
            "epoch 1, loss 0.4118, train acc 0.850, test acc 0.876, time 29.0 sec\n",
            "epoch 2, loss 0.2536, train acc 0.907, test acc 0.898, time 28.9 sec\n",
            "epoch 3, loss 0.2148, train acc 0.921, test acc 0.911, time 29.0 sec\n",
            "epoch 4, loss 0.1858, train acc 0.931, test acc 0.896, time 28.8 sec\n",
            "epoch 5, loss 0.1631, train acc 0.939, test acc 0.920, time 28.7 sec\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}